{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XIyP_0r6zuVc"
      },
      "source": [
        "# Training Large Language Models in 2bit with `aqlm`, `transformers` and `PEFT`\n",
        "\n",
        "<a target=\"_blank\" href=\"https://colab.research.google.com/github/Vahe1994/AQLM/blob/main/notebooks/aqlm_2bit_training.ipynb\">\n",
        "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
        "</a>\n",
        "\n",
        "Welcome to this notebook, which walks through the recent `aqlm` integration in `transformers`. AQLM is a 2-bit quantization technique that compresses large models with minimal performance degradation.\n",
        "\n",
        "In this notebook, we will learn how to load a large model (`Mixtral-8x7b`) in 2-bit precision and fine-tune it on Google Colab with the PEFT library from Hugging Face 🤗.\n"
      ]
    },
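    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "AQLM builds on additive quantization: each small group of weights is stored as one or more integer codes that index learned codebooks, and the group is reconstructed by summing the selected codebook vectors. The toy sketch below illustrates that decoding step and the resulting bits-per-weight arithmetic in plain Python. The codebook sizes and values are made up, per-channel scales are ignored, and this is not how the `aqlm` CUDA kernels are implemented. It also assumes that the `1x16` suffix of the checkpoint loaded below means one codebook of 2^16 entries over groups of 8 weights, which works out to 16 / 8 = 2 bits per weight.\n",
        "\n",
        "```python\n",
        "import random\n",
        "\n",
        "def decode_group(codes, codebooks):\n",
        "    # Additive decoding: sum the codebook vector selected by each code.\n",
        "    group_size = len(codebooks[0][0])\n",
        "    out = [0.0] * group_size\n",
        "    for code, codebook in zip(codes, codebooks):\n",
        "        for i, value in enumerate(codebook[code]):\n",
        "            out[i] += value\n",
        "    return out\n",
        "\n",
        "def bits_per_weight(num_codebooks, code_bits, group_size):\n",
        "    # A group of `group_size` weights costs `num_codebooks` codes of `code_bits` bits each.\n",
        "    return num_codebooks * code_bits / group_size\n",
        "\n",
        "# Hypothetical toy setup: 2 codebooks of 4 entries each, groups of 3 weights.\n",
        "random.seed(0)\n",
        "toy_codebooks = [\n",
        "    [[random.uniform(-1, 1) for _ in range(3)] for _ in range(4)]\n",
        "    for _ in range(2)\n",
        "]\n",
        "weights = decode_group([1, 3], toy_codebooks)  # one code per codebook\n",
        "print(len(weights))               # 3 reconstructed weights\n",
        "print(bits_per_weight(1, 16, 8))  # 1x16 over groups of 8 -> 2.0\n",
        "```"
      ]
    },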
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A_VgSpl4Dsr3"
      },
      "source": [
        "**Install the `aqlm` library**\n",
        "- It's the only extra dependency needed to run AQLM models.\n",
        "- Add `[gpu]` to install the required CUDA-specific dependencies.\n",
        "- Install a recent `accelerate` release and the `main` branches of `transformers` and `peft` to properly support it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FuXIFTFapAMI"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "!pip install \"aqlm[gpu]>=1.1.0\"\n",
        "!pip install git+https://github.com/huggingface/peft.git@main\n",
        "!pip install \"accelerate>=0.27.0\"\n",
        "!pip install git+https://github.com/huggingface/transformers.git@main\n",
        "!pip install datasets\n",
        "!pip install bitsandbytes # for the 8-bit optimizer only"
      ]
    },
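    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because the install output is captured, it can be useful to confirm that the minimum versions pinned above (`aqlm>=1.1.0`, `accelerate>=0.27.0`) were actually installed. A minimal sketch using the standard-library `importlib.metadata`; the simplified version comparison is an assumption that only looks at the leading numeric `X.Y.Z` parts and ignores pre-release or dev suffixes (`packaging.version` handles those properly).\n",
        "\n",
        "```python\n",
        "from importlib.metadata import PackageNotFoundError, version\n",
        "\n",
        "def parse(v):\n",
        "    # Keep the leading numeric components, e.g. '4.39.0.dev0' -> (4, 39, 0),\n",
        "    # padded to three parts so '1.1' compares like '1.1.0'.\n",
        "    parts = []\n",
        "    for piece in v.split('.'):\n",
        "        if not piece.isdigit():\n",
        "            break\n",
        "        parts.append(int(piece))\n",
        "    while len(parts) < 3:\n",
        "        parts.append(0)\n",
        "    return tuple(parts)\n",
        "\n",
        "def meets_minimum(installed, minimum):\n",
        "    return parse(installed) >= parse(minimum)\n",
        "\n",
        "# Minimum versions taken from the install cell\n",
        "for package, minimum in [('aqlm', '1.1.0'), ('accelerate', '0.27.0')]:\n",
        "    try:\n",
        "        installed = version(package)\n",
        "    except PackageNotFoundError:\n",
        "        print(package, 'not installed')\n",
        "        continue\n",
        "    status = 'OK' if meets_minimum(installed, minimum) else 'too old'\n",
        "    print(package, installed, status)\n",
        "```"
      ]
    },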
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MJ-5idQwzvg-"
      },
      "source": [
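        "First, let's load the model we are going to use: `Mixtral-8x7b`! Note that the original model takes up over 90GB in half precision, while this 2-bit AQLM checkpoint stores its quantized weights at roughly 2 bits each instead of 16, small enough to fit on a single Colab GPU."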
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "E0Nl5mWL0k2T"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "\n",
        "model_id = \"ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf\"\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
        "model = AutoModelForCausalLM.from_pretrained(model_id, device_map=\"auto\", torch_dtype=\"auto\", low_cpu_mem_usage=True)"
      ]
    },
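    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A back-of-the-envelope estimate shows why the 2-bit checkpoint fits where the half-precision model cannot. The total parameter count of about 46.7B for Mixtral-8x7B is an assumption taken from its public model card, and the estimate ignores embeddings, codebooks, scales and activations, so real memory use is somewhat higher.\n",
        "\n",
        "```python\n",
        "def weight_memory_gb(num_params, bits_per_weight):\n",
        "    # bytes = params * bits / 8, reported in decimal gigabytes (1e9 bytes)\n",
        "    return num_params * bits_per_weight / 8 / 1e9\n",
        "\n",
        "MIXTRAL_PARAMS = 46.7e9  # assumed total parameter count\n",
        "\n",
        "fp16_gb = weight_memory_gb(MIXTRAL_PARAMS, 16)\n",
        "aqlm_gb = weight_memory_gb(MIXTRAL_PARAMS, 2)\n",
        "print(f'fp16: {fp16_gb:.1f} GB, 2-bit: {aqlm_gb:.1f} GB')\n",
        "```"
      ]
    },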
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mp2gMi1ZzGET"
      },
      "source": [
        "**Add LoRA**\n",
        "\n",
        "To change the model's behavior, we have to make it trainable. Since the quantized weights stay frozen, we do that by adding a small set of trainable parameters (LoRA adapters) on top of the untrainable quantized ones."
      ]
    },
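    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Conceptually, LoRA keeps the frozen weight `W` and learns a low-rank update `B @ A` of rank `r`, scaled by `lora_alpha / r`, so an adapted layer computes `W x + (lora_alpha / r) * B (A x)`. The plain-Python sketch below shows that forward pass for a single vector with toy shapes; it illustrates the math only and is not PEFT's implementation (which, for example, also applies `lora_dropout` to the adapter input). Because `B` starts at zero, the adapted layer initially behaves exactly like the frozen one.\n",
        "\n",
        "```python\n",
        "def matvec(M, v):\n",
        "    # Multiply a matrix (list of rows) by a vector.\n",
        "    return [sum(m * x for m, x in zip(row, v)) for row in M]\n",
        "\n",
        "def lora_forward(W, A, B, x, lora_alpha, r):\n",
        "    # Frozen path plus scaled low-rank update: W x + (alpha / r) * B (A x)\n",
        "    base = matvec(W, x)\n",
        "    update = matvec(B, matvec(A, x))\n",
        "    scaling = lora_alpha / r\n",
        "    return [b + scaling * u for b, u in zip(base, update)]\n",
        "\n",
        "# Toy shapes: d_out = 2, d_in = 3, r = 1\n",
        "W = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0]]\n",
        "A = [[0.5, 0.5, 0.5]]  # r x d_in\n",
        "B = [[0.0], [0.0]]     # d_out x r, zero-initialised\n",
        "x = [1.0, 2.0, 3.0]\n",
        "print(lora_forward(W, A, B, x, lora_alpha=32, r=1))  # same as W x: [7.0, 2.0]\n",
        "```"
      ]
    },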
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Ybeyl20n3dYH",
        "outputId": "0efda156-4886-4718-9877-e93a17dc02d2"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "trainable params: 3,407,872 || all params: 6,550,261,760 || trainable%: 0.05202650099894634\n"
          ]
        }
      ],
      "source": [
        "from peft import LoraConfig, get_peft_model\n",
        "\n",
        "config = LoraConfig(\n",
        "    r=8,\n",
        "    lora_alpha=32,\n",
        "    target_modules=[\"q_proj\", \"k_proj\", \"o_proj\"],\n",
        "    lora_dropout=0.05,\n",
        "    bias=\"none\",\n",
        "    task_type=\"CAUSAL_LM\"\n",
        ")\n",
        "\n",
        "model = get_peft_model(model, config)\n",
        "model.print_trainable_parameters()\n",
        "model.enable_input_require_grads() # it's needed for gradient checkpointing"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4xSPH1D_Wv9x"
      },
      "source": [
        "Here we add a trainable adapter on top of every `q_proj`, `k_proj` and `o_proj` linear layer, i.e. the attention projections named in `target_modules`."
      ]
    },
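    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Each adapted linear layer mapping `d_in -> d_out` gains `r * (d_in + d_out)` trainable parameters (`A` is `r x d_in`, `B` is `d_out x r`). The sketch below totals them over all decoder layers. The Mixtral shapes it uses (hidden size 4096, 8 key/value heads of dimension 128 so `k_proj` outputs 1024, and 32 decoder layers) are assumptions taken from the public Mixtral config.\n",
        "\n",
        "```python\n",
        "def lora_trainable_params(r, layer_shapes, num_layers):\n",
        "    # layer_shapes maps module name -> (d_in, d_out) for one decoder layer\n",
        "    per_layer = sum(r * (d_in + d_out) for d_in, d_out in layer_shapes.values())\n",
        "    return per_layer * num_layers\n",
        "\n",
        "# Assumed Mixtral-8x7B attention projection shapes\n",
        "mixtral_attn = {\n",
        "    'q_proj': (4096, 4096),\n",
        "    'k_proj': (4096, 1024),\n",
        "    'o_proj': (4096, 4096),\n",
        "}\n",
        "\n",
        "all_three = lora_trainable_params(8, mixtral_attn, num_layers=32)\n",
        "k_and_o = lora_trainable_params(\n",
        "    8, {k: mixtral_attn[k] for k in ('k_proj', 'o_proj')}, num_layers=32\n",
        ")\n",
        "print(f'{all_three:,}')  # 5,505,024\n",
        "print(f'{k_and_o:,}')    # 3,407,872\n",
        "```\n",
        "\n",
        "The `trainable params` figure printed by `print_trainable_parameters()` above (3,407,872) equals the `k_proj` + `o_proj` total alone. That is the count you get when one entry of `target_modules` (for example a misspelled `q_prok`) matches no module: PEFT does not necessarily raise an error for an unmatched entry as long as another entry matches, so it is worth comparing the printed count against the expected total."
      ]
    },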
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FCc64bfnmd3j"
      },
      "source": [
        "**Loading a dataset**\n",
        "\n",
        "Let's load a common dataset, `Abirate/english_quotes`, to fine-tune our model on famous quotes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s6f4z8EYmcJ6"
      },
      "outputs": [],
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "data = load_dataset(\"Abirate/english_quotes\")\n",
        "data = data.map(lambda samples: tokenizer(samples[\"quote\"]), batched=True)"
      ]
    },
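    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The training cell sets `tokenizer.pad_token = tokenizer.eos_token` and batches the tokenized quotes with `DataCollatorForLanguageModeling(tokenizer, mlm=False)`. A simplified plain-Python sketch of what such a collator produces for causal LM training: sequences are padded to the longest one in the batch, and `labels` copy `input_ids` with padded positions set to `-100` so the loss ignores them (the model shifts labels internally to predict the next token). Right padding and the token ids are assumptions for illustration. Since padded positions are identified by comparing against the pad token id, reusing EOS as the pad token also masks any genuine EOS tokens, as the first toy sequence shows.\n",
        "\n",
        "```python\n",
        "def collate_causal_lm(sequences, pad_id):\n",
        "    # Right-pad every sequence to the longest one in the batch.\n",
        "    max_len = max(len(s) for s in sequences)\n",
        "    input_ids, attention_mask, labels = [], [], []\n",
        "    for seq in sequences:\n",
        "        n_pad = max_len - len(seq)\n",
        "        padded = seq + [pad_id] * n_pad\n",
        "        input_ids.append(padded)\n",
        "        attention_mask.append([1] * len(seq) + [0] * n_pad)\n",
        "        # Every position equal to pad_id is ignored by the loss (-100).\n",
        "        labels.append([-100 if tok == pad_id else tok for tok in padded])\n",
        "    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}\n",
        "\n",
        "# Hypothetical ids: 2 plays the role of EOS, reused as the pad token.\n",
        "batch = collate_causal_lm([[5, 6, 2], [7, 8]], pad_id=2)\n",
        "print(batch['labels'])  # [[5, 6, -100], [7, 8, -100]]\n",
        "```"
      ]
    },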
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_0MOtwf3zdZp"
      },
      "source": [
        "Run the cell below to start training! For the sake of the demo, we only run it for a few steps (`max_steps=10`) to showcase how this integration works with existing tools in the HF ecosystem."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 481
        },
        "id": "jq0nX33BmfaC",
        "outputId": "7f470980-c49e-4230-b947-ad43510f1bee"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='10' max='10' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [10/10 13:02, Epoch 0/1]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>2.042200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>1.293400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>1.447500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>1.433600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>5</td>\n",
              "      <td>1.725900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>6</td>\n",
              "      <td>1.506400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>7</td>\n",
              "      <td>1.549600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>8</td>\n",
              "      <td>1.038300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>9</td>\n",
              "      <td>1.603300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>10</td>\n",
              "      <td>1.676400</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/plain": [
              "TrainOutput(global_step=10, training_loss=1.531658697128296, metrics={'train_runtime': 861.2678, 'train_samples_per_second': 0.046, 'train_steps_per_second': 0.012, 'total_flos': 56809829376000.0, 'train_loss': 1.531658697128296, 'epoch': 0.02})"
            ]
          },
          "execution_count": 6,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "import transformers\n",
        "\n",
        "tokenizer.pad_token = tokenizer.eos_token\n",
        "\n",
        "trainer = transformers.Trainer(\n",
        "    model=model,\n",
        "    train_dataset=data[\"train\"],\n",
        "    args=transformers.TrainingArguments(\n",
        "        per_device_train_batch_size=1,\n",
        "        gradient_accumulation_steps=4,\n",
        "        gradient_checkpointing=True,\n",
        "        warmup_steps=2,\n",
        "        max_steps=10,\n",
        "        learning_rate=2e-4,\n",
        "        fp16=True,\n",
        "        logging_steps=1,\n",
        "        output_dir=\"outputs\",\n",
        "        optim=\"adamw_bnb_8bit\",\n",
        "        logging_first_step=True,\n",
        "    ),\n",
        "    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),\n",
        ")\n",
        "model.config.use_cache = False  # silence the warnings. Please re-enable for inference!\n",
        "trainer.train()"
      ]
    },
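    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "With `per_device_train_batch_size=1` and `gradient_accumulation_steps=4` on a single GPU, each optimizer step sees 4 sequences, so `max_steps=10` consumes 40 quotes. The sketch below does that arithmetic. The `train` split size of 2,508 rows is an assumption (check `len(data['train'])`); with it, the fraction of an epoch rounds to the `'epoch': 0.02` reported in the training output above.\n",
        "\n",
        "```python\n",
        "def samples_seen(per_device_batch, grad_accum_steps, max_steps, num_devices=1):\n",
        "    # Sequences per optimizer step = per-device batch * devices * accumulation steps\n",
        "    effective_batch = per_device_batch * num_devices * grad_accum_steps\n",
        "    return effective_batch, effective_batch * max_steps\n",
        "\n",
        "DATASET_ROWS = 2508  # assumed size of the 'train' split\n",
        "\n",
        "effective_batch, total = samples_seen(1, 4, 10)\n",
        "print(effective_batch, total)          # 4 40\n",
        "print(round(total / DATASET_ROWS, 2))  # fraction of one epoch\n",
        "```"
      ]
    },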
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "05iBmtP6X3Mq"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
